from __future__ import division, print_function, absolute_import

import numpy as np
import pickle
from CDTools.tools.cmath import *
import torch as t
from RIP_tools import *
from CDTools.tools.propagators import generate_angular_spectrum_propagator as gasp
from CDTools.tools.propagators import near_field


#
# This section runs the numerical experiment designed
# to test the robustness of RIP to a poorly aligned probe
#

# Frequency cutoff of the band-limited test probe (passed to
# generate_blr_probe) -- presumably measured in Fourier-space pixels
probe_maxk = 128

# Set GPU to False when no GPU is available; forwarded to reconstruct()
GPU = True

# This defines a reconstruction with R=4
resolution = 2 * int(0.4 * probe_maxk)

# Side length of the square array the probe is defined on
field_size = 2 * (resolution + probe_maxk)
probe_fov_radius = (4 * field_size) // 10

# Generate the test probe on a field_size x field_size grid
probe_shape = [field_size, field_size]
nearfield = generate_blr_probe(probe_shape, probe_fov_radius, probe_maxk)



# Fraction of the reconstruction trimmed from each edge before the error is
# measured, so that only the central region is compared
error_check_padding = 0.25
padding = int(error_check_padding * resolution)

# One inner list of final losses / errors per defocus value
losses, errors = [], []

# The depth of field is wavelength / (2 * NA^2), i.e.
# (spacing^2 / wavelength) * (1/2) * (N / N_p)^2, with N = field_size and
# N_p = probe_maxk. Expressed in units of spacing^2 / wavelength it is:
dof = 0.5 * (field_size / probe_maxk) ** 2

# 51 evenly spaced defocus values, from -3 to +3 depths of field
defocuses = np.linspace(-3, 3, 51)

# Number of independent random test images reconstructed per defocus value
attempts_per_defocus = 200
# Iteration count passed to reconstruct() for every retrieval
n_iterations = 1000

# Central region on which the error is evaluated. `-padding or None` is used
# as the stop index so the region is the full image when padding is 0
# (a plain `-0` stop index would select an empty region). For padding > 0
# this is identical to np.s_[padding:-padding, padding:-padding].
error_region = np.s_[padding:-padding or None, padding:-padding or None]

for propagation_distance in defocuses:
    loop_errors = []
    loop_losses = []

    # The reciprocal-space phase mask for angular spectrum propagation depends
    # only on the defocus, so build it once per defocus rather than once per
    # attempt. Spacing (1, 1) and wavelength 1 keep the distance in the
    # spacing^2 / wavelength units that dof is expressed in.
    # NOTE(review): assumes gasp's positional args are (shape, spacing,
    # wavelength, z) -- confirm against CDTools.
    prop = gasp(nearfield.shape[:2], (1, 1), 1,
                dof * propagation_distance).to(dtype=t.float32)

    for _attempt in range(attempts_per_defocus):

        # First, we generate a random complex image
        real = np.random.randn(resolution, resolution)
        imag = np.random.randn(resolution, resolution)
        original_image = real + 1j * imag
        original_image = complex_to_torch(original_image).to(dtype=t.float32)

        # We upsample the probe to avoid aliasing
        simulation_nearfield = expand_probe(nearfield, original_image)

        # Now we simulate diffraction using the unpropagated probe
        exit_wave = interact(original_image, simulation_nearfield)
        diffraction = propagate(exit_wave)
        intensities = measure(diffraction, nearfield.shape[:2])

        # Propagate the true probe to get a misaligned (defocused) probe guess.
        # Kept inside the loop in case reconstruct() modifies its probe
        # argument in place. NOTE(review): confirm before hoisting this too.
        prop_nearfield = near_field(nearfield, prop)

        # Now we perform the retrieval with the misaligned probe guess
        retrieval, loss = reconstruct(intensities, prop_nearfield, resolution,
                                      0.4, n_iterations, GPU=GPU,
                                      schedule=True)
        ret = torch_to_complex(retrieval)
        im = torch_to_complex(original_image)

        # We extract the resulting error on the central region only
        error, _originals, _comparisons = compare_images(ret, im, error_region)
        final_loss = loss[-1]
        print('Propagation By', propagation_distance, 'DOF Leads To Loss',
              '%.1e' % final_loss, 'And Error', '%.1e' % error)

        # And save it out
        loop_losses.append(final_loss)
        loop_errors.append(error)

    losses.append(loop_losses)
    errors.append(loop_errors)

    
import os  # local import: only needed for preparing the output directory

# Bundle the probe, the per-defocus losses/errors and the defocus values
results = {'probe': nearfield, 'losses':losses, 'errors':errors, 'defocuses': defocuses}

# Make sure the output directory exists, so that a long sweep does not fail
# only at the very end when the results are written
output_path = 'data/propagation randomim trial.pickle'
output_dir = os.path.dirname(output_path)
if output_dir and not os.path.isdir(output_dir):
    os.makedirs(output_dir)

with open(output_path, 'wb') as f:
    pickle.dump(results, f)
